package logger

import (
	"context"
	"errors"
	"fmt"
	"github.com/go-kratos/kratos/v2/log"
	"gorm.io/gorm"
	gormlogger "gorm.io/gorm/logger"
	"time"
)

// logger adapts a kratos *log.Helper to gorm's logger.Interface.
// Config is embedded so its fields (SlowThreshold, IgnoreRecordNotFoundError,
// ParameterizedQueries) can be read directly on the logger value.
type logger struct {
	Config
	// Helper receives every message this adapter emits.
	Helper *log.Helper
}

// Config controls how the gorm logger adapter reports SQL activity.
type Config struct {
	// SlowThreshold: in Trace, a statement whose elapsed time exceeds this
	// value is logged at warn level. Zero disables slow-query detection.
	SlowThreshold             time.Duration
	// IgnoreRecordNotFoundError: when true, Trace does not log an error that
	// matches gorm.ErrRecordNotFound at error level; such calls fall through
	// to the slow-query / normal (warn / info) branches instead.
	IgnoreRecordNotFoundError bool
	// ParameterizedQueries: when true, ParamsFilter drops the bound
	// parameters (returns nil) and passes only the SQL text through.
	ParameterizedQueries      bool
}

// Option customizes a logger while it is being built by New.
// A non-nil error signals that the option could not be applied.
type Option func(c *logger) error

// WithConfig returns an Option that overwrites the logger's whole Config with c.
//
// Every field is replaced, so fields left at their zero value in c override
// the defaults chosen by New (for example, a zero SlowThreshold turns off
// slow-query warnings). The returned Option never fails.
func WithConfig(c Config) Option {
	apply := func(target *logger) error {
		target.Config = c
		return nil
	}
	return apply
}

// New adapts the kratos logger l into a gorm logger.Interface.
//
// Every record is tagged with the key/value pair "module"="gorm". The
// defaults are SlowThreshold = 200ms and IgnoreRecordNotFoundError = false.
// The given options are then applied in order on top of those defaults.
//
// If an option returns an error, the error is logged at error level through
// the adapter's own helper and the remaining options are still applied. Any
// changes the failing option made before returning its error are kept.
// New always returns a usable, non-nil Interface.
func New(l log.Logger, options ...Option) gormlogger.Interface {
	lg := logger{
		Config: Config{
			SlowThreshold:             200 * time.Millisecond,
			IgnoreRecordNotFoundError: false,
		},
		Helper: log.NewHelper(log.With(l, "module", "gorm")),
	}

	for _, option := range options {
		if err := option(&lg); err != nil {
			// The signature has no error return. Returning nil here (as before)
			// silently threw away both the adapter and the error, so report the
			// failure instead and keep going.
			lg.Helper.Errorf("gorm logger: ignoring failed option: %v", err)
		}
	}
	return lg
}

// LogMode satisfies gormlogger.Interface but ignores the requested level.
// It returns a copy of l that shares the same Config and Helper.
//
// Because the level is discarded, gorm-side level changes (such as asking for
// a quieter level) do not affect what this adapter emits.
// NOTE(review): presumably verbosity is meant to be controlled by the kratos
// logger's own level filtering; confirm.
func (l logger) LogMode(_ gormlogger.LogLevel) gormlogger.Interface {
	clone := l
	return clone
}

// Info logs a formatted gorm informational message (s is the format string,
// i its arguments). The message is emitted at kratos Debug level, not Info,
// through a helper bound to ctx.
func (l logger) Info(ctx context.Context, s string, i ...interface{}) {
	ctxLog := l.Helper.WithContext(ctx)
	ctxLog.Debugf(s, i...)
}

// Warn logs a formatted gorm warning (s is the format string, i its
// arguments) at kratos Warn level, through a helper bound to ctx.
func (l logger) Warn(ctx context.Context, s string, i ...interface{}) {
	ctxLog := l.Helper.WithContext(ctx)
	ctxLog.Warnf(s, i...)
}

// Error logs a formatted gorm error message (s is the format string, i its
// arguments) at kratos Error level, through a helper bound to ctx.
func (l logger) Error(ctx context.Context, s string, i ...interface{}) {
	ctxLog := l.Helper.WithContext(ctx)
	ctxLog.Errorf(s, i...)
}

// Trace reports one executed statement. It always calls fc exactly once to
// obtain the SQL text and the affected row count, then emits a single record:
//
//   - error level, when err is non-nil, unless IgnoreRecordNotFoundError is
//     set and err matches gorm.ErrRecordNotFound;
//   - warn level, when SlowThreshold is non-zero and the time elapsed since
//     begin exceeds it;
//   - info level otherwise.
//
// Each record carries "elapsed" (milliseconds as float64), "rows" and "sql".
// The error record also carries "error", and the slow record carries "msg".
// NOTE(review): the slow check is a strict ">", but the message text says ">=".
func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
	elapsed := time.Since(begin)
	elapsedMs := float64(elapsed.Nanoseconds()) / 1e6
	sql, rows := fc()
	h := l.Helper.WithContext(ctx)

	failed := err != nil
	if failed && l.IgnoreRecordNotFoundError && errors.Is(err, gorm.ErrRecordNotFound) {
		// A "not found" result is treated like a normal outcome.
		failed = false
	}

	if failed {
		h.Errorw("error", err.Error(), "elapsed", elapsedMs, "rows", rows, "sql", sql)
		return
	}

	if l.SlowThreshold != 0 && elapsed > l.SlowThreshold {
		h.Warnw("elapsed", elapsedMs, "rows", rows, "sql", sql, "msg", fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold))
		return
	}

	h.Infow("elapsed", elapsedMs, "rows", rows, "sql", sql)
}

// ParamsFilter decides which bound parameters accompany sql.
// NOTE(review): presumably gorm calls this when rendering SQL for logging;
// confirm against the gorm version in use.
//
// When Config.ParameterizedQueries is false, sql and params are returned
// unchanged. When it is true, params are dropped (nil) and only sql is
// returned. ctx is unused.
func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
	if !l.Config.ParameterizedQueries {
		return sql, params
	}
	return sql, nil
}
